import os
import gzip
import pybedtools
from Bio import SeqIO
from Bio.Seq import reverse_complement


# Genome assembly name. It selects the "<assembly>.chrom.sizes" file read by
# the chromosome helpers and is read as a module-level global by
# read_ensembl_snornas.
assembly = 'hg38'

def read_chromosome_names(assembly, directory="/osc-fs_home/scratch/mdehoon/Data/Genomes/"):
    """Map short sequence names to the chromosome names in a chrom.sizes file.

    Reads ``<directory>/<assembly>/<assembly>.chrom.sizes``, which must hold
    two whitespace-separated columns per line (chromosome name, size), and
    derives a lookup key from each chromosome name:

    * ``chr1`` -> ``1`` (and ``chrM`` -> ``MT``)
    * ``chrUn_KI270302v1`` -> ``KI270302.1``
    * ``chr1_KI270706v1_random`` / ``..._alt`` -> ``KI270706.1``

    Args:
        assembly: assembly name, e.g. ``"hg38"``.
        directory: directory holding one sub-directory per assembly.

    Returns:
        dict mapping each derived key to the chromosome name as written in
        the file.

    Raises:
        AssertionError: if a name does not follow one of the patterns above.
        ValueError: if a line does not have exactly two fields, the size is
            not an integer, or an accession lacks a single ``v`` separator.
        Exception: if a name has more than three ``_``-separated parts.
    """
    chromosomes = {"chr%s" % i for i in list(range(1, 23)) + ['X', 'Y']}
    names = {}
    filename = "%s.chrom.sizes" % assembly
    path = os.path.join(directory, assembly, filename)
    # The context manager closes the file even when a line fails validation.
    with open(path) as handle:
        for line in handle:
            chromosome, size = line.split()
            int(size)  # the size is not used here, only checked to be numeric
            terms = chromosome.split("_")
            if len(terms) == 1:
                assert chromosome.startswith("chr")
                if chromosome == "chrM":
                    key = "MT"
                else:
                    key = chromosome[3:]
                    assert chromosome in chromosomes
            elif len(terms) == 2:
                # Unplaced scaffold: chrUn_<accession>v<version>
                assert terms[0] == "chrUn"
                key, version = terms[1].split("v")
                key = "%s.%s" % (key, version)
            elif len(terms) == 3:
                # Alternate or unlocalized: chrN_<accession>v<version>_<kind>
                assert terms[2] in ("alt", "random")
                key, version = terms[1].split("v")
                key = "%s.%s" % (key, version)
                assert terms[0] in chromosomes
            else:
                raise Exception("Unknown chromosome %s" % chromosome)
            names[key] = chromosome
    return names

def read_chromosome_sizes(assembly, directory="/osc-fs_home/scratch/mdehoon/Data/Genomes/"):
    """Read chromosome sizes from ``<directory>/<assembly>/<assembly>.chrom.sizes``.

    Each line must hold two whitespace-separated fields: a chromosome name
    starting with ``chr`` and its integer size.

    Args:
        assembly: assembly name, e.g. ``"hg38"``.
        directory: directory holding one sub-directory per assembly.

    Returns:
        dict mapping chromosome name to size (int).

    Raises:
        AssertionError: if a chromosome name does not start with ``chr``.
        ValueError: if a line does not have exactly two fields or the size
            is not an integer.
    """
    sizes = {}
    filename = "%s.chrom.sizes" % assembly
    path = os.path.join(directory, assembly, filename)
    # The context manager closes the file even when a line fails validation.
    with open(path) as handle:
        for line in handle:
            chromosome, size = line.split()
            assert chromosome.startswith("chr")
            sizes[chromosome] = int(size)
    return sizes

def read_rfam_annotations(path="/osc-fs_home/scratch/mdehoon/Data/RNAcentral/rfam_annotations.tsv.gz"):
    """Collect the snoRNA descriptions from a gzipped, tab-separated Rfam table.

    Every line must have exactly nine tab-separated fields. The first field
    is used as the accession and the ninth as the description. Only rows
    whose description starts with ``"Small nucleolar RNA"`` are kept; if an
    accession occurs more than once, the last matching row wins.

    Args:
        path: location of the gzip-compressed table.

    Returns:
        dict mapping accession to description.

    Raises:
        AssertionError: if a line does not have exactly nine fields.
    """
    print("Reading", path)
    descriptions = {}
    # The context manager closes the stream even if a malformed line is hit.
    with gzip.open(path, "rt") as stream:
        for line in stream:
            words = line.strip().split("\t")
            assert len(words) == 9
            description = words[8]
            if description.startswith("Small nucleolar RNA"):
                descriptions[words[0]] = description
    return descriptions

def read_snoDB_annotations(path="/osc-fs_home/scratch/mdehoon/Data/snoDB/snoDB.tsv"):
    """Read box type and symbol for every entry of a snoDB tab-separated export.

    The first line must be the expected 56-column header. Each following
    line contributes one entry keyed by its ``Ensembl`` column; when that
    column is empty (pseudo-snoRNAs) the ``gtf`` column is used as the key.

    Args:
        path: location of the snoDB ``.tsv`` file.

    Returns:
        dict mapping accession to a ``(box, symbol)`` tuple.

    Raises:
        AssertionError: if the header differs from the expected layout.
        StopIteration: if the file is empty.
        IndexError: if a data line has fewer than twelve fields.
    """
    # NOTE(review): four consecutive columns are all named "skov_frt"; this
    # mirrors the header check as originally written - confirm against the file.
    expected_header = (
        "id", "Symbol", "Synonym", "Box", "Chromosome", "Start", "End",
        "Strand", "Length", "gtf", "HGNC", "Ensembl", "RNA central", "RefSeq",
        "NCBI", "Rfam", "snoRNABase", "snOPY", "snoRNA Atlas", "conservation",
        "Host Gene ID", "Host Symbol", "Host Synonym", "Biotype",
        "Host Start", "Host End", "Host Strand", "Host Function",
        "target count", "target summary", "lncrna", "mirna", "ncrna",
        "protein_coding", "pseudogene", "rrna", "snorna", "snrna", "trna",
        "expressed", "breast", "liver", "prostate", "ovaries",
        "skeletal_muscle", "skov_frt", "skov_frt", "skov_frt", "skov_frt",
        "host_breast", "host_liver", "host_npr036", "host_ovn218",
        "host_skeletal_muscle", "Sequence", "snps",
    )
    print("Reading", path)
    annotations = {}
    # The context manager closes the file even when the header check fails.
    with open(path) as stream:
        header = next(stream).strip().split("\t")
        assert len(header) == len(expected_header), \
            "expected %d columns, found %d" % (len(expected_header), len(header))
        for index, (found, wanted) in enumerate(zip(header, expected_header)):
            assert found == wanted, \
                "column %d: expected %r, found %r" % (index, wanted, found)
        for line in stream:
            words = line.strip().split("\t")
            accession = words[11]  # Ensembl
            if accession == "":  # for pseudo-snoRNAs
                accession = words[9]  # gtf
            annotations[accession] = (words[3], words[1])  # (Box, Symbol)
    return annotations

def _snorna_cross_references(db_xrefs):
    """Return the (RNAcentral, HGNC_trans_name) identifiers among db_xrefs."""
    RNAcentral_xref = None
    HGNC_trans_name = None
    for db_xref in db_xrefs:
        # Split on the first colon only, so that an identifier which itself
        # contains a colon does not break the unpacking.
        db, xref = db_xref.split(":", 1)
        if db == "RNAcentral":
            RNAcentral_xref = xref
        elif db == "HGNC_trans_name":
            HGNC_trans_name = xref
    return RNAcentral_xref, HGNC_trans_name


def read_ensembl_snornas(rfam_descriptions, snoDB_annotations):
    """Extract all snoRNA transcripts from the Ensembl GRCh38 flat files.

    Walks the per-chromosome and non-chromosomal Ensembl release-100 GenBank
    files, picks every ``misc_RNA`` feature whose note is ``snoRNA`` (skipping
    those whose HGNC transcript name starts with ``SCARNA``), and for each one
    records its sequence, a description, and - when the record's sequence
    name is known to read_chromosome_names - a BED12 interval whose sequence
    is verified against the hg38 2bit genome.

    Args:
        rfam_descriptions: dict mapping RNAcentral accession to an Rfam
            description; looked up with the feature's RNAcentral
            cross-reference.
        snoDB_annotations: dict mapping Ensembl gene ID (without version
            suffix) to a ``(box, symbol)`` tuple.

    Returns:
        A tuple ``(intervals, transcripts, descriptions)``:
        ``intervals`` is a pybedtools.BedTool of the mapped loci,
        ``transcripts`` maps transcript name to its sequence (reverse
        complemented for minus-strand features), and ``descriptions`` maps
        transcript name to ``"<gene> [<box> box] [<description>]"``.

    Raises:
        AssertionError: if a record or feature violates the expected layout,
            or a transcript disagrees with the genome sequence.
        Exception: if a snoRNA feature has no RNAcentral cross-reference.
    """
    print("Reading genome")
    # The 2bit file must stay open while sequences are fetched from it, so
    # the whole scan runs inside this context manager.
    with open('/osc-fs_home/scratch/mdehoon/Data/Genomes/hg38/hg38.2bit', 'rb') as genome_handle:
        genome = SeqIO.parse(genome_handle, 'twobit')
        chromosomes = read_chromosome_names(assembly)
        intervals = []
        transcripts = {}
        descriptions = {}
        directory = "/osc-fs_home/scratch/mdehoon/Data/Ensembl"
        # None stands for the file with the non-chromosomal sequences.
        numbers = [str(i) for i in list(range(1, 23)) + ['X', 'Y', 'MT']] + [None]
        for number in numbers:
            if number is None:
                filename = "Homo_sapiens.GRCh38.100.nonchromosomal.dat.gz"
            else:
                filename = "Homo_sapiens.GRCh38.100.chromosome.%s.dat.gz" % number
            path = os.path.join(directory, filename)
            print("Reading", path)
            with gzip.open(path, 'rt') as handle:
                for record in SeqIO.parse(handle, 'genbank'):
                    sequence = str(record.seq)
                    assert len(record.annotations['accessions']) == 1
                    accession = record.annotations['accessions'][0]
                    # <type>:GRCh38:<name>:<start>:<end>:<strand>
                    terms = accession.split(":")
                    assert len(terms) == 6
                    if number is None:
                        assert terms[0] in ('chromosome', 'scaffold')
                    else:
                        assert terms[0] == 'chromosome'
                    assert terms[1] == 'GRCh38'
                    if number is not None:
                        assert terms[2] == number
                    # None when the sequence has no entry in chrom.sizes.
                    chromosome = chromosomes.get(terms[2])
                    offset = int(terms[3]) - 1  # 1-based start -> 0-based
                    assert int(terms[5]) == 1
                    assert int(terms[4]) - offset == len(sequence)
                    for feature in record.features:
                        if feature.type != 'misc_RNA':
                            continue
                        notes = feature.qualifiers['note']
                        assert len(notes) == 1
                        if notes[0] != 'snoRNA':
                            continue
                        names = feature.qualifiers['standard_name']
                        assert len(names) == 1
                        name = names[0]
                        genes = feature.qualifiers['gene']
                        assert len(genes) == 1
                        gene = genes[0]
                        RNAcentral_xref, HGNC_trans_name = _snorna_cross_references(
                            feature.qualifiers['db_xref'])
                        if RNAcentral_xref is None:
                            raise Exception("failed to find RNAcentral cross-reference for %s, %s" % (name, gene))
                        if HGNC_trans_name is not None and HGNC_trans_name.startswith("SCARNA"):
                            print("Skipping scaRNA %s, %s" % (name, gene))
                            continue
                        # Use the dictionary passed in by the caller rather
                        # than a module-level global.
                        box, symbol = snoDB_annotations.get(gene.split(".")[0], (None, None))
                        # Look up by the RNAcentral accession, not by whichever
                        # cross-reference happened to be listed last.
                        description = rfam_descriptions.get(RNAcentral_xref)
                        if description is None:
                            description = symbol
                        if box is not None:
                            description = "%s %s box %s" % (gene, box, description)
                        elif description is not None:
                            description = "%s %s" % (gene, description)
                        else:
                            # Nothing known beyond the gene ID; avoid writing
                            # the literal text "None" into the description.
                            description = gene
                        descriptions[name] = description
                        # Feature coordinates are relative to the record.
                        feature_start = int(feature.location.start)
                        feature_end = int(feature.location.end)
                        transcript = sequence[feature_start:feature_end]
                        if feature.location.strand == 1:
                            strand = '+'
                        else:
                            assert feature.location.strand == -1
                            strand = '-'
                            transcript = reverse_complement(transcript)
                        assert name not in transcripts
                        transcripts[name] = transcript
                        if chromosome is None:
                            continue
                        start = feature_start + offset
                        end = feature_end + offset
                        # Cross-check the Ensembl sequence against the genome.
                        genomic = genome[chromosome].seq[start:end].upper()
                        if strand == '+':
                            assert transcript == genomic
                        else:
                            assert reverse_complement(transcript) == genomic
                        lengths = "%s," % (end - start)
                        # BED12 with a single block spanning the whole locus.
                        fields = [chromosome, start, end, name, "0", strand, start, end, "0", 1, lengths, "0,"]
                        interval = pybedtools.create_interval_from_list(fields)
                        intervals.append(interval)
    intervals = pybedtools.BedTool(intervals)
    return intervals, transcripts, descriptions

def read_refseq_snorna_loci():
    """Open ``snoRNA.bed`` in the working directory as a pybedtools.BedTool."""
    bed_path = "snoRNA.bed"
    print("Reading", bed_path)
    return pybedtools.BedTool(bed_path)

def read_refseq_snorna_transcripts():
    """Load every FASTA record of ``snoRNA.fa`` into a list.

    The records are materialized eagerly, so the returned list stays valid
    after the file is closed or overwritten.
    """
    fasta_path = "snoRNA.fa"
    print("Reading", fasta_path)
    return list(SeqIO.parse(fasta_path, 'fasta'))

# Load all inputs. The order of these calls fixes the order of the progress
# messages printed by the readers.
sizes = read_chromosome_sizes(assembly)
rfam_descriptions = read_rfam_annotations()
snoDB_descriptions = read_snoDB_annotations()

ensembl_results = read_ensembl_snornas(rfam_descriptions, snoDB_descriptions)
ensembl_intervals, ensembl_sequences, ensembl_descriptions = ensembl_results

ensembl_total = len(ensembl_sequences)
print("Total number of Ensembl snoRNAs: %d" % ensembl_total)
ensembl_mapped = len(ensembl_intervals)
print("Total number of mapped Ensembl snoRNAs: %d" % ensembl_mapped)

refseq_intervals = read_refseq_snorna_loci()

# Keep the Ensembl loci with no same-strand overlap in the RefSeq set
# (v=True: report entries without overlap; s=True: require same strand).
novel_intervals = ensembl_intervals.intersect(refseq_intervals, v=True, s=True)

# Build one PSL record (21 tab-separated columns) for every novel Ensembl
# locus, describing it as a perfect, ungapped, single-block alignment of the
# transcript to the genome.
psl_lines = {}
for interval in novel_intervals:
    qSize = interval.end - interval.start
    blockCount = int(interval.fields[9])
    assert blockCount == 1
    qName = interval.name
    tName = interval.chrom
    tStart = interval.start
    tEnd = interval.end
    columns = (
        qSize,             # matches: every base matches
        0,                 # misMatches
        0,                 # repMatches
        0,                 # nCount
        0,                 # qNumInsert
        0,                 # qBaseInsert
        0,                 # tNumInsert
        0,                 # tBaseInsert
        interval.strand,   # strand
        qName,             # qName
        qSize,             # qSize
        0,                 # qStart
        qSize,             # qEnd
        tName,             # tName
        sizes[tName],      # tSize
        tStart,            # tStart
        tEnd,              # tEnd
        blockCount,        # blockCount
        "%d," % qSize,     # blockSizes
        "0,",              # qStarts
        "%d," % tStart,    # tStarts
    )
    psl_lines[qName] = "\t".join(map(str, columns)) + "\n"


# Rewrite snoRNA.fa in place: the existing RefSeq records first, followed by
# the novel Ensembl transcripts. read_refseq_snorna_transcripts returns a
# fully loaded list, so the file can safely be truncated afterwards.
records = read_refseq_snorna_transcripts()
filename = "snoRNA.fa"
print("Writing %s" % filename)
# The context manager closes (and flushes) the file even if a write fails.
with open(filename, 'w') as handle:
    for record in records:
        handle.write(format(record, 'fasta'))

    for name, sequence in ensembl_sequences.items():
        # Only transcripts that received a PSL record are novel.
        if name not in psl_lines:
            continue
        description = ensembl_descriptions[name]
        if description is None:
            handle.write('>%s\n' % name)
        else:
            handle.write('>%s %s\n' % (name, description))
        handle.write('%s\n' % sequence)

# Combine all RefSeq loci with the novel Ensembl loci (those that received a
# PSL record) and sort the result.
merged_intervals = [interval for interval in refseq_intervals]
merged_intervals.extend(
    interval for interval in ensembl_intervals if interval.name in psl_lines
)
merged_intervals = pybedtools.BedTool(merged_intervals).sort()

# Rewrite snoRNA.psl in place: add the existing records to psl_lines, then
# write one record per merged interval, in the sorted interval order.
filename = "snoRNA.psl"
print("Reading %s" % filename)
# The file is read completely and closed before it is reopened for writing.
with open(filename) as handle:
    for line in handle:
        words = line.split()
        assert len(words) == 21
        name = words[9]  # qName column
        # An existing record must not clash with a novel Ensembl record.
        assert name not in psl_lines
        psl_lines[name] = line

print("Writing %s" % filename)
# The context manager closes the file even if an interval has no PSL record.
with open(filename, 'w') as handle:
    for interval in merged_intervals:
        handle.write(psl_lines[interval.name])
